11package com .github .davidfantasy .mybatisplus .generatorui .controller ;
22
3+ import cn .hutool .core .io .IoUtil ;
34import com .github .davidfantasy .mybatisplus .generatorui .common .Result ;
45import com .github .davidfantasy .mybatisplus .generatorui .common .ResultGenerator ;
56import com .github .davidfantasy .mybatisplus .generatorui .common .ServiceException ;
1213import com .google .common .collect .Maps ;
1314import lombok .extern .slf4j .Slf4j ;
1415import org .springframework .beans .factory .annotation .Autowired ;
16+ import org .springframework .http .MediaType ;
17+ import org .springframework .http .ResponseEntity ;
1518import org .springframework .web .bind .annotation .*;
1619import org .springframework .web .multipart .MultipartFile ;
1720
18- import javax .servlet .http .HttpServletResponse ;
19- import java .io .*;
21+ import java .io .File ;
22+ import java .io .IOException ;
23+ import java .io .InputStream ;
2024import java .nio .file .Files ;
2125import java .util .List ;
2226import java .util .Map ;
@@ -34,41 +38,41 @@ public class TemplateController {
3438 private OutputFileInfoService outputFileInfoService ;
3539
3640 @ GetMapping ("/download" )
37- public void download (HttpServletResponse res , @ RequestParam String fileType ) throws IOException {
41+ public ResponseEntity < byte []> download (@ RequestParam String fileType ) throws IOException {
3842 if (Strings .isNullOrEmpty (fileType )) {
3943 log .error ("fileType不能为空" );
40- return ;
44+ return ResponseEntity . badRequest (). build () ;
4145 }
4246 UserConfig userConfig = userConfigStore .getUserConfigFromFile ();
4347 if (userConfig == null ) {
4448 InputStream tplIn = TemplateUtil .getBuiltInTemplate (fileType );
45- download (res , tplIn );
46- return ;
49+ return toDownloadEntity (tplIn );
4750 }
4851 List <OutputFileInfo > fileInfos = userConfig .getOutputFiles ();
4952 for (OutputFileInfo fileInfo : fileInfos ) {
5053 if (fileType .equals (fileInfo .getFileType ())) {
5154 if (fileInfo .isBuiltIn ()
5255 && Strings .isNullOrEmpty (fileInfo .getTemplatePath ())) {
5356 InputStream tplIn = TemplateUtil .getBuiltInTemplate (fileType );
54- download ( res , tplIn );
57+ return toDownloadEntity ( tplIn );
5558 } else {
5659 String tplPath = fileInfo .getTemplatePath ();
5760 if (tplPath .startsWith ("file:" )) {
5861 tplPath = tplPath .replaceFirst ("file:" , "" );
5962 }
6063 File tplFile = new File (tplPath );
6164 if (tplFile .exists ()) {
62- download ( res , Files .newInputStream (tplFile .toPath ()));
65+ return toDownloadEntity ( Files .newInputStream (tplFile .toPath ()));
6366 } else {
6467 throw new ServiceException ("未找到模板文件:" + fileInfo .getTemplatePath ());
6568 }
6669 }
67- break ;
6870 }
6971 }
72+ return ResponseEntity .notFound ().build ();
7073 }
7174
75+
7276 @ PostMapping ("/upload" )
7377 public Result upload (@ RequestParam ("file" ) MultipartFile file , @ RequestParam ("fileType" ) String fileType ) {
7478 Map <String , Object > params = Maps .newHashMap ();
@@ -78,27 +82,17 @@ public Result upload(@RequestParam("file") MultipartFile file, @RequestParam("fi
7882 return ResultGenerator .genSuccessResult (params );
7983 }
8084
81- private void download (HttpServletResponse res , InputStream tplIn ) throws UnsupportedEncodingException {
82- if (tplIn != null ) {
83- res .setCharacterEncoding ("utf-8" );
84- res .setContentType ("multipart/form-data;charset=UTF-8" );
85- try {
86- OutputStream os = res .getOutputStream ();
87- byte [] b = new byte [2048 ];
88- int length ;
89- while ((length = tplIn .read (b )) > 0 ) {
90- os .write (b , 0 , length );
91- }
92- } catch (Exception e ) {
93- e .printStackTrace ();
94- } finally {
95- try {
96- tplIn .close ();
97- } catch (IOException ignored ) {
98- }
99- }
85+ /**
86+ * 从输入流构建http响应
87+ * @param tplIn 流
88+ * @return http响应
89+ */
90+ private ResponseEntity <byte []> toDownloadEntity (InputStream tplIn ) {
91+ if (tplIn == null ) {
92+ return ResponseEntity .notFound ().build ();
10093 }
101- }
94+ return ResponseEntity . ok (). contentType ( MediaType . MULTIPART_FORM_DATA ). body ( IoUtil . readBytes ( tplIn ));
10295
10396
97+ }
10498}
0 commit comments