From 5b4159602f004b9ebfa0a1b6badece51d2c5f849 Mon Sep 17 00:00:00 2001 From: jackform <296256067@qq.com> Date: Fri, 23 Dec 2022 12:59:51 +0000 Subject: [PATCH 1/6] remove useless print --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index e6d3453..1581715 100644 --- a/train.py +++ b/train.py @@ -134,8 +134,8 @@ def load_dataset(logger, args): # test # input_list_train = input_list_train[:24] # input_list_val = input_list_val[:24] - print(f'train {len(input_list_train)} {input_list_train}') - print(f'valid {len(input_list_val)} {input_list_val}') + #print(f'train {len(input_list_train)} {input_list_train}') + #print(f'valid {len(input_list_val)} {input_list_val}') train_dataset = MyDataset(input_list_train, args.max_len) val_dataset = MyDataset(input_list_val, args.max_len) From 6beda20f0b9dd60ce833c4a5870cf50b44d204a4 Mon Sep 17 00:00:00 2001 From: jackform <296256067@qq.com> Date: Fri, 23 Dec 2022 14:39:50 +0000 Subject: [PATCH 2/6] add print --- train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train.py b/train.py index 1581715..2041766 100644 --- a/train.py +++ b/train.py @@ -382,6 +382,7 @@ def main(): # 当用户使用GPU,并且GPU可用时 args.cuda = torch.cuda.is_available() and not args.no_cuda device = 'cuda:0' if args.cuda else 'cpu' + print(f'args no_cuda:{args.no_cuda} cuda {args.cuda}') args.device = device logger.info('using device:{}'.format(device)) From 28c3f6a2d5e6556b0024f11f7e15d2f47d17e2cf Mon Sep 17 00:00:00 2001 From: jackform <296256067@qq.com> Date: Fri, 23 Dec 2022 14:57:54 +0000 Subject: [PATCH 3/6] add print and default use gpu --- train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 2041766..d54ad46 100644 --- a/train.py +++ b/train.py @@ -380,9 +380,10 @@ def main(): # 创建日志对象 logger = create_logger(args) # 当用户使用GPU,并且GPU可用时 - args.cuda = torch.cuda.is_available() and not args.no_cuda - device = 'cuda:0' if args.cuda else 'cpu' + #args.cuda = torch.cuda.is_available() and not args.no_cuda + args.cuda = torch.cuda.is_available() print(f'args no_cuda:{args.no_cuda} cuda {args.cuda}') + device = 'cuda:0' if args.cuda else 'cpu' args.device = device logger.info('using device:{}'.format(device)) From 0e0482e5d65ad4af759e91306c1333241cfa6d4d Mon Sep 17 00:00:00 2001 From: jackform <296256067@qq.com> Date: Fri, 23 Dec 2022 15:01:52 +0000 Subject: [PATCH 4/6] use gpu --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index d54ad46..046481c 100644 --- a/train.py +++ b/train.py @@ -381,8 +381,8 @@ def main(): logger = create_logger(args) # 当用户使用GPU,并且GPU可用时 #args.cuda = torch.cuda.is_available() and not args.no_cuda - args.cuda = torch.cuda.is_available() - print(f'args no_cuda:{args.no_cuda} cuda {args.cuda}') + args.cuda = True + print(f'====>args no_cuda:{args.no_cuda} cuda {args.cuda}') device = 'cuda:0' if args.cuda else 'cpu' args.device = device logger.info('using device:{}'.format(device)) From 3dd699df7ba7d2032e30f0479449e5f1eae0c016 Mon Sep 17 00:00:00 2001 From: jackform <296256067@qq.com> Date: Fri, 23 Dec 2022 15:11:43 +0000 Subject: [PATCH 5/6] print cuda situation --- train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train.py b/train.py index 046481c..83b7616 100644 --- a/train.py +++ b/train.py @@ -382,6 +382,7 @@ def main(): # 当用户使用GPU,并且GPU可用时 #args.cuda = torch.cuda.is_available() and not args.no_cuda args.cuda = True + print(f'====>cuda:{torch.cuda.is_available()}') print(f'====>args no_cuda:{args.no_cuda} cuda {args.cuda}') device = 'cuda:0' if args.cuda else 'cpu' args.device = device From fe6ba7b55c074a6c51aa3ec6d2e6fd619dd72df9 Mon Sep 17 00:00:00 2001 From: jackform <296256067@qq.com> Date: Thu, 5 Jan 2023 09:09:05 +0000 Subject: [PATCH 6/6] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=B7=A8=E5=9F=9F?= =?UTF-8?q?=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- interact.py | 41 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/interact.py b/interact.py index e30ade0..2c546cf 100644 --- a/interact.py +++ b/interact.py @@ -45,8 +45,8 @@ def set_args(): parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False, help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数") # parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的') - parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断') - parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度") + parser.add_argument('--max_len', type=int, default=250, help='每个utterance的最大长度,超过指定长度则进行截断') + parser.add_argument('--max_history_len', type=int, default=1, help="dialogue history的最大长度") parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测') return parser.parse_args() @@ -110,6 +110,25 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf') return logits +region_words = ['保险', '人寿保险', '平安', '医疗保险', '汽车保险', '医疗险', '家财险', '意外险', '保险公司', '重疾险', '保险费用', + '人寿', '寿险', '保险费', '险', '人身保险', '理赔', '平安保险', '保费', '车险', '保险金', '被保险人', '投保人', + '保险费率', '财产保险', '保单', '续保', '投保', '核保', '医保', + # 特定保险begin + '综合意外', '鹏城保', '百万家财', '水滴保', '福禄鑫尊', '国寿瑞鑫', '鑫裕金', '鑫尊宝', '国寿福', '同佑e生', '金佑人生', + '金福合家欢', '重庆渝惠保', '国寿康宁', '泰康贴心保', '新冠隔离津贴', '火车隔离津贴', '外卖准时宝', '泰康', '悟空保', '南充充惠保', + '春城惠民保', '太平福禄御禧', '太平洋金佑', '国华金如意', '医疗保障', '医疗保健', + # 特定保险end + # 特定术语begin + '告知义务', '说明义务', '现金价值', '犹豫期', '宽限期', '条款', + # 特定术语end + '重大疾病', + # '年金', + # '退休计划', + # '养老金', + '免赔', '保障期', '退保', '保额', '受益人', '可以保', '能保么', + ] + + def main(): args = set_args() logger = create_logger(args) @@ -139,12 +158,30 @@ def main(): try: text = input("user:") # text = "你好" + if args.save_samples_path: samples_file.write("user:{}\n".format(text)) text_ids = tokenizer.encode(text, add_special_tokens=False) history.append(text_ids) input_ids = [tokenizer.cls_token_id] # 每个input以[CLS]为开头 + #先做领域识别, 在关键字里面的为保险领域的 + in_region = False + for region_word in region_words: + if not text.__contains__(region_word): + pass + else: + in_region = True + break + if not in_region: + response_text = "这方面的问题,安安不太明白,要不您换个问法再试试,或许安安就能明白啦!" + response = tokenizer.encode(response_text, add_special_tokens=False) + history.append(response) + print("chatbot:" + "".join(response_text)) + if args.save_samples_path: + samples_file.write("chatbot:{}\n".format("".join(response_text))) + continue + for history_id, history_utr in enumerate(history[-args.max_history_len:]): input_ids.extend(history_utr) input_ids.append(tokenizer.sep_token_id)