Spaces:
Running
Running
| from unittest.mock import MagicMock, patch | |
| from urllib.parse import urljoin | |
| import pytest | |
| from langchain_ollama import ChatOllama | |
| from langflow.components.models import ChatOllamaComponent | |
@pytest.fixture
def component():
    """Provide a fresh ``ChatOllamaComponent`` to each test.

    The ``@pytest.fixture`` decorator is required: the tests in this module
    request ``component`` as a parameter, and pytest only resolves parameters
    that are registered fixtures (or injected by decorators such as
    ``unittest.mock.patch``). With the default function scope, each test gets
    its own instance.
    """
    return ChatOllamaComponent()
def test_get_model_success(mock_get, component):
    """``get_model`` should GET ``<base_url>/api/tags`` and return the model names.

    NOTE(review): ``mock_get`` is not a fixture defined in this file; it is
    presumably injected by a ``@patch(...)`` decorator on the HTTP GET call
    that is not visible here -- confirm the patch target.
    """
    ollama_url = "http://localhost:11434"
    fake_payload = {"models": [{"name": "model1"}, {"name": "model2"}]}

    # Canned HTTP reply: a JSON body listing two models and a no-op status check.
    fake_reply = MagicMock()
    fake_reply.json.return_value = fake_payload
    fake_reply.raise_for_status.return_value = None
    mock_get.return_value = fake_reply

    names = component.get_model(ollama_url)

    # get_model is expected to build the tags endpoint from the base URL itself.
    mock_get.assert_called_once_with(urljoin(ollama_url, "/api/tags"))
    assert names == ["model1", "model2"]
def test_get_model_failure(mock_get, component):
    """A failing HTTP request in ``get_model`` should surface as ``ValueError``.

    NOTE(review): ``mock_get`` is not a fixture defined in this file; it is
    presumably injected by a ``@patch(...)`` decorator that is not visible
    here -- confirm the patch target.
    """
    # Make the mocked HTTP GET raise, simulating an unreachable Ollama server.
    mock_get.side_effect = Exception("HTTP request failed")
    # get_model takes the server *base* URL and appends "/api/tags" itself,
    # so pass the base URL rather than the full endpoint; this exercises the
    # same call shape the success-path test uses.
    base_url = "http://localhost:11434"
    # The underlying error must be reported as a ValueError with this message.
    with pytest.raises(ValueError, match="Could not retrieve models"):
        component.get_model(base_url)
def test_update_build_config_mirostat_disabled(component):
    """Choosing "Disabled" marks the mirostat eta/tau fields advanced and clears them."""
    config = {
        "mirostat_eta": {"advanced": False, "value": 0.1},
        "mirostat_tau": {"advanced": False, "value": 5},
    }

    result = component.update_build_config(config, "Disabled", "mirostat")

    tuning_fields = ("mirostat_eta", "mirostat_tau")
    # First both fields must be flagged advanced...
    for key in tuning_fields:
        assert result[key]["advanced"] is True
    # ...and then both of their values must be reset.
    for key in tuning_fields:
        assert result[key]["value"] is None
def test_update_build_config_mirostat_enabled(component):
    """Choosing "Mirostat 2.0" un-hides eta/tau and fills in eta=0.2, tau=10."""
    # Two independent field dicts, both starting visible and unset.
    config = {
        name: {"advanced": False, "value": None}
        for name in ("mirostat_eta", "mirostat_tau")
    }

    result = component.update_build_config(config, "Mirostat 2.0", "mirostat")

    assert result["mirostat_eta"]["advanced"] is False
    assert result["mirostat_tau"]["advanced"] is False
    expected_values = {"mirostat_eta": 0.2, "mirostat_tau": 10}
    for name, expected in expected_values.items():
        assert result[name]["value"] == expected
def test_update_build_config_model_name(mock_get, component):
    """Updating ``model_name`` fills its options from the mocked HTTP response.

    NOTE(review): ``mock_get`` is not a fixture defined in this file; it is
    presumably injected by a ``@patch(...)`` decorator that is not visible
    here -- confirm the patch target.
    """
    # Canned HTTP reply listing two available models.
    reply = MagicMock()
    reply.json.return_value = {"models": [{"name": "model1"}, {"name": "model2"}]}
    reply.raise_for_status.return_value = None
    mock_get.return_value = reply

    config = {
        "base_url": {"load_from_db": False, "value": None},
        "model_name": {"options": []},
    }
    result = component.update_build_config(config, None, "model_name")

    assert result["model_name"]["options"] == ["model1", "model2"]
def test_update_build_config_keep_alive(component):
    """The keep-alive flag sets ``keep_alive`` to "-1"/"0" and marks it advanced."""
    build_config = {"keep_alive": {"value": None, "advanced": False}}
    # (flag value, expected keep_alive value), applied in this order.
    cases = (("Keep", "-1"), ("Immediately", "0"))
    for flag, expected_value in cases:
        # Each flag is applied to the same build_config dict.
        result = component.update_build_config(build_config, flag, "keep_alive_flag")
        assert result["keep_alive"]["value"] == expected_value
        assert result["keep_alive"]["advanced"] is True
def test_build_model(_mock_chat_ollama, component):  # noqa: PT019
    """``build_model`` returns a ``ChatOllama`` carrying the configured URL and model.

    NOTE(review): ``_mock_chat_ollama`` is unused in the body and is presumably
    injected by a ``@patch(...)`` decorator not visible here. The ``isinstance``
    check below only passes if that patch does not replace the ``ChatOllama``
    class ``build_model`` instantiates -- confirm the patch target.
    """
    settings = {
        "base_url": "http://localhost:11434",
        "model_name": "llama3.1",
        "mirostat": "Mirostat 2.0",
        "mirostat_eta": 0.2,  # deliberately a float
        "mirostat_tau": 10.0,  # deliberately a float
        "temperature": 0.2,
        "verbose": True,
    }
    for attr, value in settings.items():
        setattr(component, attr, value)

    model = component.build_model()

    assert isinstance(model, ChatOllama)
    assert model.base_url == "http://localhost:11434"
    assert model.model == "llama3.1"