blob: 9bfd28f8cc775c50d178af4ff28436de4c06dcfa [file] [log] [blame]
Nick Terrellcc388d22018-03-30 12:14:53 -07001/*
2 * Cryptographic API.
3 *
4 * Copyright (c) 2017-present, Facebook, Inc.
5 *
6 * This program is free software; you can redistribute it and/or modify it
7 * under the terms of the GNU General Public License version 2 as published by
8 * the Free Software Foundation.
9 *
10 * This program is distributed in the hope that it will be useful, but WITHOUT
11 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
12 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
13 * more details.
14 */
15#include <linux/crypto.h>
16#include <linux/init.h>
17#include <linux/interrupt.h>
18#include <linux/mm.h>
19#include <linux/module.h>
20#include <linux/net.h>
21#include <linux/vmalloc.h>
22#include <linux/zstd.h>
23
24
/* zstd compression level handed to ZSTD_getParams() by zstd_params(). */
enum { ZSTD_DEF_LEVEL = 3 };
26
27struct zstd_ctx {
28 ZSTD_CCtx *cctx;
29 ZSTD_DCtx *dctx;
30 void *cwksp;
31 void *dwksp;
32};
33
34static ZSTD_parameters zstd_params(void)
35{
36 return ZSTD_getParams(ZSTD_DEF_LEVEL, 0, 0);
37}
38
39static int zstd_comp_init(struct zstd_ctx *ctx)
40{
41 int ret = 0;
42 const ZSTD_parameters params = zstd_params();
43 const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams);
44
45 ctx->cwksp = vzalloc(wksp_size);
46 if (!ctx->cwksp) {
47 ret = -ENOMEM;
48 goto out;
49 }
50
51 ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size);
52 if (!ctx->cctx) {
53 ret = -EINVAL;
54 goto out_free;
55 }
56out:
57 return ret;
58out_free:
59 vfree(ctx->cwksp);
60 goto out;
61}
62
63static int zstd_decomp_init(struct zstd_ctx *ctx)
64{
65 int ret = 0;
66 const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
67
68 ctx->dwksp = vzalloc(wksp_size);
69 if (!ctx->dwksp) {
70 ret = -ENOMEM;
71 goto out;
72 }
73
74 ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size);
75 if (!ctx->dctx) {
76 ret = -EINVAL;
77 goto out_free;
78 }
79out:
80 return ret;
81out_free:
82 vfree(ctx->dwksp);
83 goto out;
84}
85
86static void zstd_comp_exit(struct zstd_ctx *ctx)
87{
88 vfree(ctx->cwksp);
89 ctx->cwksp = NULL;
90 ctx->cctx = NULL;
91}
92
93static void zstd_decomp_exit(struct zstd_ctx *ctx)
94{
95 vfree(ctx->dwksp);
96 ctx->dwksp = NULL;
97 ctx->dctx = NULL;
98}
99
/*
 * __zstd_init() - set up both the compression and decompression contexts.
 * @ctx: struct zstd_ctx to initialise.
 *
 * If decompression setup fails, the already-initialised compression side
 * is torn down again so nothing is left allocated.
 *
 * Return: 0 on success, otherwise the error from zstd_comp_init() or
 * zstd_decomp_init().
 */
static int __zstd_init(void *ctx)
{
	struct zstd_ctx *zctx = ctx;
	int err = zstd_comp_init(zctx);

	if (!err) {
		err = zstd_decomp_init(zctx);
		if (err)
			zstd_comp_exit(zctx);
	}
	return err;
}
112
/*
 * zstd_init() - cra_init hook; prepares the transform's struct zstd_ctx.
 * @tfm: transform whose context area is initialised.
 *
 * Return: result of __zstd_init().
 */
static int zstd_init(struct crypto_tfm *tfm)
{
	return __zstd_init(crypto_tfm_ctx(tfm));
}
119
/*
 * __zstd_exit() - free both workspaces held by a struct zstd_ctx.
 * @ctx: struct zstd_ctx to tear down.
 */
static void __zstd_exit(void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	zstd_comp_exit(zctx);
	zstd_decomp_exit(zctx);
}
125
/*
 * zstd_exit() - cra_exit hook; releases the transform's struct zstd_ctx.
 * @tfm: transform being destroyed.
 */
static void zstd_exit(struct crypto_tfm *tfm)
{
	__zstd_exit(crypto_tfm_ctx(tfm));
}
132
133static int __zstd_compress(const u8 *src, unsigned int slen,
134 u8 *dst, unsigned int *dlen, void *ctx)
135{
136 size_t out_len;
137 struct zstd_ctx *zctx = ctx;
138 const ZSTD_parameters params = zstd_params();
139
140 out_len = ZSTD_compressCCtx(zctx->cctx, dst, *dlen, src, slen, params);
141 if (ZSTD_isError(out_len))
142 return -EINVAL;
143 *dlen = out_len;
144 return 0;
145}
146
147static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
148 unsigned int slen, u8 *dst, unsigned int *dlen)
149{
150 struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
151
152 return __zstd_compress(src, slen, dst, dlen, ctx);
153}
154
155static int __zstd_decompress(const u8 *src, unsigned int slen,
156 u8 *dst, unsigned int *dlen, void *ctx)
157{
158 size_t out_len;
159 struct zstd_ctx *zctx = ctx;
160
161 out_len = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen, src, slen);
162 if (ZSTD_isError(out_len))
163 return -EINVAL;
164 *dlen = out_len;
165 return 0;
166}
167
168static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
169 unsigned int slen, u8 *dst, unsigned int *dlen)
170{
171 struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
172
173 return __zstd_decompress(src, slen, dst, dlen, ctx);
174}
175
176static struct crypto_alg alg = {
177 .cra_name = "zstd",
178 .cra_flags = CRYPTO_ALG_TYPE_COMPRESS,
179 .cra_ctxsize = sizeof(struct zstd_ctx),
180 .cra_module = THIS_MODULE,
181 .cra_init = zstd_init,
182 .cra_exit = zstd_exit,
183 .cra_u = { .compress = {
184 .coa_compress = zstd_compress,
185 .coa_decompress = zstd_decompress } }
186};
187
188static int __init zstd_mod_init(void)
189{
190 int ret;
191
192 ret = crypto_register_alg(&alg);
193 if (ret)
194 return ret;
195
196 return ret;
197}
198
/*
 * zstd_mod_fini() - module exit point; unregisters the "zstd" algorithm
 * that zstd_mod_init() registered.
 */
static void __exit zstd_mod_fini(void)
{
	crypto_unregister_alg(&alg);
}
203
204module_init(zstd_mod_init);
205module_exit(zstd_mod_fini);
206
207MODULE_LICENSE("GPL");
208MODULE_DESCRIPTION("Zstd Compression Algorithm");
209MODULE_ALIAS_CRYPTO("zstd");