diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py index 92ec1325a407..ca30f1c69f89 100644 --- a/python/mxnet/recordio.py +++ b/python/mxnet/recordio.py @@ -177,10 +177,14 @@ def pack_img(header, img, quality=80, img_fmt='.jpg'): The packed string """ assert opencv_available - if img_fmt == '.jpg': + jpg_formats = set(['.jpg', '.jpeg', '.JPG', '.JPEG']) + png_formats = set(['.png', '.PNG']) + encode_params = None + if img_fmt in jpg_formats: encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality] - elif img_fmt == '.png': + elif img_fmt in png_formats: encode_params = [cv2.IMWRITE_PNG_COMPRESSION, quality] + ret, buf = cv2.imencode(img_fmt, img, encode_params) assert ret, 'failed encoding image' return pack(header, buf.tostring())