import gradio as gr
import json, time, torch
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer
from webshop_lite import dict_to_fake_html
from predict_help import (
    Page, convert_dict_to_actions, convert_html_to_text,
    parse_results_amz, parse_item_page_amz,
    parse_results_ws, parse_item_page_ws,
    parse_results_ebay, parse_item_page_ebay,
    WEBSHOP_URL, WEBSHOP_SESSION
)
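# Demo: sim-to-real transfer of WebShop imitation-learning agents. A natural
# language goal is turned into search and click actions that drive a live
# shopping site (Amazon, eBay, or the WebShop simulator) through a Gradio UI.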
ENVIRONMENTS = ['amazon', 'webshop', 'ebay']

# IL+RL: 'webshop/il-rl-choice-bert-image_1'
# IL:    'webshop/il-choice-bert-image_0'
BERT_MODEL_PATH = 'webshop/il-choice-bert-image_0'

# Load IL models
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained('webshop/il_search_bart')
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation_side='left')
bert_tokenizer.add_tokens(['[button]', '[button_]', '[clicked button]', '[clicked button_]'], special_tokens=True)
bert_model = AutoModel.from_pretrained(BERT_MODEL_PATH, trust_remote_code=True)
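# Normalize an observation string before tokenization: lowercase, strip
# quotes, and restore the [SEP] marker casing expected by the tokenizer.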
def process_str(s):
    s = s.lower().replace('"', '').replace("'", "").strip()
    s = s.replace('[sep]', '[SEP]')
    return s
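# Reduce a search-page observation to the bare instruction text by stripping
# the game header, the search button markup, and any trailing price
# constraint, so it can be fed to the BART search-query generator.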
def process_goal(state):
    state = state.lower().replace('"', '').replace("'", "")
    state = state.replace('amazon shopping game\ninstruction:', '').replace('\n[button] search [button_]', '').strip()
    if ', and price lower than' in state:
        state = state.split(', and price lower than')[0]
    return state
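# Collate a batch of pre-tokenized state/action encodings into padded tensors,
# trimming every field to the longest real (unpadded) sequence in the batch.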
def data_collator(batch):
    state_input_ids, state_attention_mask, action_input_ids, action_attention_mask, sizes, labels, images = [], [], [], [], [], [], []
    for sample in batch:
        state_input_ids.append(sample['state_input_ids'])
        state_attention_mask.append(sample['state_attention_mask'])
        action_input_ids.extend(sample['action_input_ids'])
        action_attention_mask.extend(sample['action_attention_mask'])
        sizes.append(sample['sizes'])
        labels.append(sample['labels'])
        images.append(sample['images'])
    max_state_len = max(sum(x) for x in state_attention_mask)
    max_action_len = max(sum(x) for x in action_attention_mask)
    return {
        'state_input_ids': torch.tensor(state_input_ids)[:, :max_state_len],
        'state_attention_mask': torch.tensor(state_attention_mask)[:, :max_state_len],
        'action_input_ids': torch.tensor(action_input_ids)[:, :max_action_len],
        'action_attention_mask': torch.tensor(action_attention_mask)[:, :max_action_len],
        'sizes': torch.tensor(sizes),
        'images': torch.tensor(images),
        'labels': torch.tensor(labels),
    }
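# Generate a search query for the goal with the fine-tuned BART model
# (beam search with 5 beams; only the top sequence is used).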
def bart_predict(input):
    input_ids = bart_tokenizer(input)['input_ids']
    input_ids = torch.tensor(input_ids).unsqueeze(0)
    output = bart_model.generate(input_ids, max_length=512, num_return_sequences=5, num_beams=5)
    return bart_tokenizer.batch_decode(output.tolist(), skip_special_tokens=True)[0]
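# Score every valid click action against the current observation with the
# choice BERT model and return one: sampled from the softmax distribution by
# default, or the argmax when softmax=False.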
def bert_predict(obs, info, softmax=True):
    valid_acts = info['valid']
    assert valid_acts[0].startswith('click[')
    state_encodings = bert_tokenizer(process_str(obs), max_length=512, truncation=True, padding='max_length')
    action_encodings = bert_tokenizer(list(map(process_str, valid_acts)), max_length=512, truncation=True, padding='max_length')
    batch = {
        'state_input_ids': state_encodings['input_ids'],
        'state_attention_mask': state_encodings['attention_mask'],
        'action_input_ids': action_encodings['input_ids'],
        'action_attention_mask': action_encodings['attention_mask'],
        'sizes': len(valid_acts),
        'images': info['image_feat'].tolist(),
        'labels': 0
    }
    batch = data_collator([batch])
    outputs = bert_model(**batch)
    if softmax:
        idx = torch.multinomial(torch.nn.functional.softmax(outputs.logits[0], dim=0), 1)[0].item()
    else:
        idx = outputs.logits[0].argmax(0).item()
    return valid_acts[idx]
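# Build the values returned to the Gradio interface once a product is chosen:
# a reduced product dict, the selected options (or "None Selected"), and an
# HTML snippet linking to the product page on the target environment.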
def get_return_value(env, asin, options, search_terms, page_num, product):
    asin_url = None

    # Determine product URL + options based on environment
    if env == 'webshop':
        query_str = "+".join(search_terms.split())
        options_str = json.dumps(options)
        asin_url = (
            f'{WEBSHOP_URL}/item_page/{WEBSHOP_SESSION}/'
            f'{asin}/{query_str}/{page_num}/{options_str}'
        )
    else:
        asin_url = f"https://www.ebay.com/itm/{asin}" if env == 'ebay' else \
                   f"https://www.amazon.com/dp/{asin}"

    # Extract relevant fields for product
    product_reduced = {k: v for k, v in product.items() if k in ["asin", "Title", "Description", "BulletPoints"]}
    product_reduced["Description"] = product_reduced["Description"][:100] + "..."
    product_reduced["Features"] = product_reduced.pop("BulletPoints")
    product_reduced["Features"] = product_reduced["Features"][:100] + "..."

    # Create HTML to show link to product
    html = """<!DOCTYPE html><html><head><title>Chosen Product</title></head><body>"""
    html += f"""Product Image:<img src="{product["MainImage"]}" height="50px" /><br>""" if len(product["MainImage"]) > 0 else ""
    html += f"""Link to Product:
    <a href="{asin_url}" style="color:blue;text-decoration:underline;" target="_blank">{asin_url}</a>
    </body></html>"""

    return product_reduced, options if len(options) > 0 else "None Selected", html
def predict(obs, info):
    """
    Given a WebShop environment observation and info, predict an action:
    BART generates the search query, BERT chooses among valid click actions.
    """
    valid_acts = info['valid']
    if valid_acts[0].startswith('click['):
        return bert_predict(obs, info)
    else:
        return "search[" + bart_predict(process_goal(obs)) + "]"
def run_episode(goal, env, verbose=True):
    """
    Interact with the chosen environment (Amazon, WebShop, or eBay) to find a
    product matching the input goal.
    Input: text goal, environment name
    Output: reduced product info, selected options, and HTML linking to the found item.
    """
    env = env.lower()
    if env not in ENVIRONMENTS:
        print(f"[ERROR] Environment {env} not recognized")
    obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button_]"
    info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
    product_map = {}
    title_to_asin_map = {}
    search_results_cache = {}
    visited_asins, clicked_options = set(), set()
    sub_page_type, page_type, page_num = None, None, None
    search_terms, prod_title, asin = None, None, None
    options = {}
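    # Main interaction loop: predict an action, update the page state, scrape
    # (or load from cache) the corresponding page, and convert it into the
    # next observation and set of valid actions.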
    for i in range(100):
        # Run prediction
        action = predict(obs, info)
        if verbose:
            print("====")
            print(action)

        # Previous Page Type, Action -> Next Page Type
        action_content = action[action.find("[")+1:action.find("]")]
        prev_page_type = page_type
        if action.startswith('search['):
            page_type = Page.RESULTS
            search_terms = action_content
            page_num = 1
        elif action.startswith('click['):
            if action.startswith('click[item -'):
                prod_title = action_content[len("item -"):].strip()
                found = False
                for key in title_to_asin_map:
                    if prod_title == key:
                        asin = title_to_asin_map[key]
                        page_type = Page.ITEM_PAGE
                        visited_asins.add(asin)
                        found = True
                        break
                if not found:
                    raise Exception("Product to click not found")
            elif any(x.value in action for x in [Page.DESC, Page.FEATURES, Page.REVIEWS]):
                page_type = Page.SUB_PAGE
                sub_page_type = Page(action_content.lower())
            elif action == 'click[< prev]':
                if sub_page_type is not None:
                    page_type, sub_page_type = Page.ITEM_PAGE, None
                elif prev_page_type == Page.ITEM_PAGE:
                    page_type = Page.RESULTS
                    options, clicked_options = {}, set()
                elif prev_page_type == Page.RESULTS and page_num > 1:
                    page_type = Page.RESULTS
                    page_num -= 1
            elif action == 'click[next >]':
                page_type = Page.RESULTS
                page_num += 1
            elif action.lower() == 'click[back to search]':
                page_type = Page.SEARCH
            elif action == 'click[buy now]':
                return get_return_value(env, asin, options, search_terms, page_num, product_map[asin])
            elif prev_page_type == Page.ITEM_PAGE:
                # Clicking a product option (e.g. size, color) on the item page
                found = False
                for opt_name, opt_values in product_map[asin]["options"].items():
                    if action_content in opt_values:
                        options[opt_name] = action_content
                        page_type = Page.ITEM_PAGE
                        clicked_options.add(action_content)
                        found = True
                        break
                if not found:
                    raise Exception("Unrecognized action: " + action)
        else:
            raise Exception("Unrecognized action: " + action)
        if verbose:
            print(f"Parsing {page_type.value} page...")

        # URL -> Real HTML -> Dict of Info
        if page_type == Page.RESULTS:
            if search_terms in search_results_cache:
                data = search_results_cache[search_terms]
                if verbose:
                    print(f"Loading cached results page for \"{search_terms}\"")
            else:
                begin = time.time()
                if env == 'amazon':
                    data = parse_results_amz(search_terms, page_num, verbose)
                if env == 'webshop':
                    data = parse_results_ws(search_terms, page_num, verbose)
                if env == 'ebay':
                    data = parse_results_ebay(search_terms, page_num, verbose)
                end = time.time()
                if verbose:
                    print(f"Parsing search results took {end-begin} seconds")
                search_results_cache[search_terms] = data
                for d in data:
                    title_to_asin_map[d['Title']] = d['asin']
        elif page_type == Page.ITEM_PAGE or page_type == Page.SUB_PAGE:
            if asin in product_map:
                if verbose:
                    print("Loading cached item page for", asin)
                data = product_map[asin]
            else:
                begin = time.time()
                if env == 'amazon':
                    data = parse_item_page_amz(asin, verbose)
                if env == 'webshop':
                    data = parse_item_page_ws(asin, search_terms, page_num, options, verbose)
                if env == 'ebay':
                    data = parse_item_page_ebay(asin, verbose)
                end = time.time()
                if verbose:
                    print("Parsing item page took", end-begin, "seconds")
                product_map[asin] = data
        elif page_type == Page.SEARCH:
            if verbose:
                print("Executing search")
            obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button_]"
            info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
            continue
        else:
            raise Exception(f"Page of type `{page_type}` not found")
        # Dict of Info -> Fake HTML -> Text Observation
        begin = time.time()
        html_str = dict_to_fake_html(data, page_type, asin, sub_page_type, options, product_map, goal)
        obs = convert_html_to_text(html_str, simple=False, clicked_options=clicked_options, visited_asins=visited_asins)
        end = time.time()
        if verbose:
            print("[Page Info -> WebShop HTML -> Observation] took", end-begin, "seconds")

        # Dict of Info -> Valid Action State (Info)
        begin = time.time()
        prod_arg = product_map if page_type == Page.ITEM_PAGE else data
        info = convert_dict_to_actions(page_type, prod_arg, asin, page_num)
        end = time.time()
        if verbose:
            print("Extracting available actions took", end-begin, "seconds")

        if i == 50:
            return get_return_value(env, asin, options, search_terms, page_num, product_map[asin])
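# Sketch of a direct (non-UI) call, assuming network access to the target
# site; the goal below is shortened from one of the interface examples:
#   product, selected_options, link_html = run_episode(
#       "I want to find a gold floor lamp with a glass shade and a nickel finish", "amazon")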
gr.Interface(
    fn=run_episode,
    inputs=[
        gr.Textbox(lines=7, label="Input Text"),
        gr.Radio(['Amazon', 'eBay'], type="value", value="Amazon", label='Environment')
    ],
    outputs=[
        gr.JSON(label="Selected Product"),
        gr.JSON(label="Selected Options"),
        gr.HTML()
    ],
    examples=[
        ["I want to find a gold floor lamp with a glass shade and a nickel finish that i can use for my living room, and price lower than 270.00 dollars", "Amazon"],
        ["I need some cute heart-shaped glittery cupcake picks as a gift to bring to a baby shower", "Amazon"],
        ["I want to buy ballet shoes which have rubber sole in grey suede color and a size of 6", "Amazon"],
        ["I would like a 7 piece king comforter set decorated with flowers and is machine washable", "Amazon"],
        ["I'm trying to find white bluetooth speakers that are not only water resistant but also come with stereo sound", "eBay"],
        ["find me the soy free 3.5 ounce 4-pack of dang thai rice chips, and make sure they are the aged cheddar flavor. i also need the ones in the resealable bags", "eBay"],
        ["I am looking for a milk chocolate of 1 pound size in a single pack for valentine day", "eBay"],
        ["I'm looking for a mini pc intel core desktop computer which supports with windows 11", "eBay"]
    ],
| title="WebShop", | |
| article="<p style='padding-top:15px;text-align:center;'>To learn more about this project, check out the <a href='https://webshop-pnlp.github.io/' target='_blank'>project page</a>!</p>", | |
| description="<p style='text-align:center;'>Sim-to-real transfer of agent trained on WebShop to search a desired product on Amazon from any natural language query!</p>", | |
| ).launch(inline=False) | |