// SPDX-License-Identifier: GPL-2.0-or-later
2#include <linux/zlib.h>
3#include "compress.h"
4
/*
 * One reusable inflate context.  Idle contexts are chained through ->next
 * on the global z_erofs_deflate_head list.
 */
struct z_erofs_deflate {
	/* next idle context on the list */
	struct z_erofs_deflate *next;
	/* zlib stream state; z.workspace is allocated with vmalloc() */
	struct z_stream_s z;
	/* private copy of the current input when it shares a page with the output */
	u8 bounce[PAGE_SIZE];
};
10
/* taken around every update of z_erofs_deflate_head */
static DEFINE_SPINLOCK(z_erofs_deflate_lock);
/* requested number of streams (module parameter) / streams allocated so far */
static unsigned int z_erofs_deflate_nstrms, z_erofs_deflate_avail_strms;
/* head of the singly-linked list of idle stream contexts */
static struct z_erofs_deflate *z_erofs_deflate_head;
/* decompressors wait here until a stream is pushed back onto the list */
static DECLARE_WAIT_QUEUE_HEAD(z_erofs_deflate_wq);

/* "deflate_streams" module parameter, mode 0444; 0 is replaced at init time */
module_param_named(deflate_streams, z_erofs_deflate_nstrms, uint, 0444);
17
18void z_erofs_deflate_exit(void)
19{
20	/* there should be no running fs instance */
21	while (z_erofs_deflate_avail_strms) {
22		struct z_erofs_deflate *strm;
23
24		spin_lock(&z_erofs_deflate_lock);
25		strm = z_erofs_deflate_head;
26		if (!strm) {
27			spin_unlock(&z_erofs_deflate_lock);
28			continue;
29		}
30		z_erofs_deflate_head = NULL;
31		spin_unlock(&z_erofs_deflate_lock);
32
33		while (strm) {
34			struct z_erofs_deflate *n = strm->next;
35
36			vfree(strm->z.workspace);
37			kfree(strm);
38			--z_erofs_deflate_avail_strms;
39			strm = n;
40		}
41	}
42}
43
44int __init z_erofs_deflate_init(void)
45{
46	/* by default, use # of possible CPUs instead */
47	if (!z_erofs_deflate_nstrms)
48		z_erofs_deflate_nstrms = num_possible_cpus();
49	return 0;
50}
51
52int z_erofs_load_deflate_config(struct super_block *sb,
53			struct erofs_super_block *dsb, void *data, int size)
54{
55	struct z_erofs_deflate_cfgs *dfl = data;
56	static DEFINE_MUTEX(deflate_resize_mutex);
57	static bool inited;
58
59	if (!dfl || size < sizeof(struct z_erofs_deflate_cfgs)) {
60		erofs_err(sb, "invalid deflate cfgs, size=%u", size);
61		return -EINVAL;
62	}
63
64	if (dfl->windowbits > MAX_WBITS) {
65		erofs_err(sb, "unsupported windowbits %u", dfl->windowbits);
66		return -EOPNOTSUPP;
67	}
68	mutex_lock(&deflate_resize_mutex);
69	if (!inited) {
70		for (; z_erofs_deflate_avail_strms < z_erofs_deflate_nstrms;
71		     ++z_erofs_deflate_avail_strms) {
72			struct z_erofs_deflate *strm;
73
74			strm = kzalloc(sizeof(*strm), GFP_KERNEL);
75			if (!strm)
76				goto failed;
77			/* XXX: in-kernel zlib cannot customize windowbits */
78			strm->z.workspace = vmalloc(zlib_inflate_workspacesize());
79			if (!strm->z.workspace) {
80				kfree(strm);
81				goto failed;
82			}
83
84			spin_lock(&z_erofs_deflate_lock);
85			strm->next = z_erofs_deflate_head;
86			z_erofs_deflate_head = strm;
87			spin_unlock(&z_erofs_deflate_lock);
88		}
89		inited = true;
90	}
91	mutex_unlock(&deflate_resize_mutex);
92	erofs_info(sb, "EXPERIMENTAL DEFLATE feature in use. Use at your own risk!");
93	return 0;
94failed:
95	mutex_unlock(&deflate_resize_mutex);
96	z_erofs_deflate_exit();
97	return -ENOMEM;
98}
99
/*
 * z_erofs_deflate_decompress - inflate one request with a pooled zlib stream
 * @rq:   decompression request; rq->in[] / rq->out[] are the input and output
 *        page arrays, rq->inputsize / rq->outputsize their byte lengths and
 *        rq->pageofs_in / rq->pageofs_out the offsets inside the first page
 *        of each array
 * @pgpl: page pool handed to erofs_allocpage() for short-lived pages
 *
 * Input and output are mapped one page at a time with kmap_local_page() and
 * fed to zlib_inflate() until the loop is left.
 *
 * Return: 0 on success; otherwise the error from z_erofs_fixup_insize(),
 * -EIO if zlib_inflateInit2() or zlib_inflateEnd() fails, -ENOMEM if a page
 * cannot be allocated, or -EFSCORRUPTED if the data is rejected (see the
 * erofs_err() messages below).
 */
int z_erofs_deflate_decompress(struct z_erofs_decompress_req *rq,
			       struct page **pgpl)
{
	/* number of pages spanned by the output, including the leading offset */
	const unsigned int nrpages_out =
		PAGE_ALIGN(rq->pageofs_out + rq->outputsize) >> PAGE_SHIFT;
	/* number of pages spanned by the input */
	const unsigned int nrpages_in =
		PAGE_ALIGN(rq->inputsize) >> PAGE_SHIFT;
	struct super_block *sb = rq->sb;
	unsigned int insz, outsz, pofs;
	struct z_erofs_deflate *strm;
	u8 *kin, *kout = NULL;
	bool bounced = false;
	/* no / ni: index of the currently mapped output / input page */
	int no = -1, ni = 0, j = 0, zerr, err;

	/* 1. get the exact DEFLATE compressed size */
	/* NOTE(review): presumably updates rq->inputsize, which is re-read below - confirm */
	kin = kmap_local_page(*rq->in);
	err = z_erofs_fixup_insize(rq, kin + rq->pageofs_in,
			min_t(unsigned int, rq->inputsize,
			      sb->s_blocksize - rq->pageofs_in));
	if (err) {
		kunmap_local(kin);
		return err;
	}

	/* 2. get an available DEFLATE context */
again:
	spin_lock(&z_erofs_deflate_lock);
	strm = z_erofs_deflate_head;
	if (!strm) {
		/* list is empty: sleep until it is non-empty, then retry under the lock */
		spin_unlock(&z_erofs_deflate_lock);
		wait_event(z_erofs_deflate_wq, READ_ONCE(z_erofs_deflate_head));
		goto again;
	}
	z_erofs_deflate_head = strm->next;
	spin_unlock(&z_erofs_deflate_lock);

	/* 3. multi-call decompress */
	insz = rq->inputsize;
	outsz = rq->outputsize;
	/* NOTE(review): negative windowBits presumably selects raw DEFLATE - confirm against zlib docs */
	zerr = zlib_inflateInit2(&strm->z, -MAX_WBITS);
	if (zerr != Z_OK) {
		err = -EIO;
		goto failed_zinit;
	}

	pofs = rq->pageofs_out;
	/* the first input chunk is whatever is left of the first input page */
	strm->z.avail_in = min_t(u32, insz, PAGE_SIZE - rq->pageofs_in);
	insz -= strm->z.avail_in;
	strm->z.next_in = kin + rq->pageofs_in;
	/* zero avail_out makes the first iteration map output page 0 */
	strm->z.avail_out = 0;

	while (1) {
		/* current output page is full (or none mapped yet): move to the next */
		if (!strm->z.avail_out) {
			if (++no >= nrpages_out || !outsz) {
				erofs_err(sb, "insufficient space for decompressed data");
				err = -EFSCORRUPTED;
				break;
			}

			if (kout)
				kunmap_local(kout);
			strm->z.avail_out = min_t(u32, outsz, PAGE_SIZE - pofs);
			outsz -= strm->z.avail_out;
			/* no page supplied for this slot: use a short-lived one */
			if (!rq->out[no]) {
				rq->out[no] = erofs_allocpage(pgpl, rq->gfp);
				if (!rq->out[no]) {
					/* kout was already unmapped above; don't unmap it again */
					kout = NULL;
					err = -ENOMEM;
					break;
				}
				set_page_private(rq->out[no],
						 Z_EROFS_SHORTLIVED_PAGE);
			}
			kout = kmap_local_page(rq->out[no]);
			strm->z.next_out = kout + pofs;
			/* only the first output page starts at a non-zero offset */
			pofs = 0;
		}

		/* current input page is consumed and more input remains: map the next */
		if (!strm->z.avail_in && insz) {
			if (++ni >= nrpages_in) {
				erofs_err(sb, "invalid compressed data");
				err = -EFSCORRUPTED;
				break;
			}

			if (kout) { /* unlike kmap(), take care of the orders */
				/* remember the write position so it can be restored after remapping */
				j = strm->z.next_out - kout;
				kunmap_local(kout);
			}
			kunmap_local(kin);
			strm->z.avail_in = min_t(u32, insz, PAGE_SIZE);
			insz -= strm->z.avail_in;
			kin = kmap_local_page(rq->in[ni]);
			strm->z.next_in = kin;
			/* a new input page has not been copied to the bounce buffer yet */
			bounced = false;
			if (kout) {
				kout = kmap_local_page(rq->out[no]);
				strm->z.next_out = kout + j;
			}
		}

		/*
		 * Handle overlapping: Use bounced buffer if the compressed
		 * data is under processing; Or use short-lived pages from the
		 * on-stack pagepool where pages share among the same request
		 * and not _all_ inplace I/O pages are needed to be doubled.
		 */
		if (!bounced && rq->out[no] == rq->in[ni]) {
			memcpy(strm->bounce, strm->z.next_in, strm->z.avail_in);
			strm->z.next_in = strm->bounce;
			bounced = true;
		}

		/* replace any later input page that is the current output page by a copy */
		for (j = ni + 1; j < nrpages_in; ++j) {
			struct page *tmppage;

			if (rq->out[no] != rq->in[j])
				continue;
			tmppage = erofs_allocpage(pgpl, rq->gfp);
			if (!tmppage) {
				err = -ENOMEM;
				goto failed;
			}
			set_page_private(tmppage, Z_EROFS_SHORTLIVED_PAGE);
			copy_highpage(tmppage, rq->in[j]);
			rq->in[j] = tmppage;
		}

		zerr = zlib_inflate(&strm->z, Z_SYNC_FLUSH);
		/* leave the loop on any status but Z_OK, or once no output space is left */
		if (zerr != Z_OK || !(outsz + strm->z.avail_out)) {
			/* output filled while zlib still reports Z_OK: fine for partial decoding */
			if (zerr == Z_OK && rq->partial_decoding)
				break;
			/* zlib reported end of stream with no further output pages pending */
			if (zerr == Z_STREAM_END && !outsz)
				break;
			erofs_err(sb, "failed to decompress %d in[%u] out[%u]",
				  zerr, rq->inputsize, rq->outputsize);
			err = -EFSCORRUPTED;
			break;
		}
	}
failed:
	/* an earlier error takes precedence over a zlib_inflateEnd() failure */
	if (zlib_inflateEnd(&strm->z) != Z_OK && !err)
		err = -EIO;
	if (kout)
		kunmap_local(kout);
failed_zinit:
	kunmap_local(kin);
	/* 4. push back DEFLATE stream context to the global list */
	spin_lock(&z_erofs_deflate_lock);
	strm->next = z_erofs_deflate_head;
	z_erofs_deflate_head = strm;
	spin_unlock(&z_erofs_deflate_lock);
	/* let waiters blocked in step 2 retry */
	wake_up(&z_erofs_deflate_wq);
	return err;
}
255