"""Integration tests for the visualization plugins API.

Covers plugin listing/detail endpoints, the plugin chat-completions proxy
(with ``AsyncOpenAI`` patched out), and ``inference_services`` config
resolution.
"""

import os
from typing import (
    Any,
    cast,
)
from unittest.mock import (
    AsyncMock,
    MagicMock,
    patch,
)

import pytest

from galaxy_test.driver.integration_util import IntegrationTestCase

# Skip this module when openai is not installed. This must happen before
# anything from openai is referenced: a top-level ``from openai import ...``
# would raise ImportError at collection time and the skip would never apply.
openai = pytest.importorskip("openai")

TEST_VISUALIZATION_PLUGINS_DIR = os.path.join(os.path.dirname(__file__), "test_visualization_plugins")

# Endpoint exercised by every chat-completions test in this module.
CHAT_COMPLETIONS_PATH = "plugins/jupyterlite/chat/completions"
# Patch target so tests never contact a real provider.
ASYNC_OPENAI_TARGET = "galaxy.webapps.galaxy.api.plugins.AsyncOpenAI"

# Request sizes the tests below expect the API to reject.
# NOTE(review): presumably just past the server-side limits defined in
# galaxy.webapps.galaxy.api.plugins — keep in sync if those limits change.
TOO_MANY_TOOLS = 129
TOO_MANY_MESSAGES = 1024 + 1
OVERSIZED_SCHEMA_CHARS = 20000


def _create_chat_payload(extra=None):
    """Return a minimal chat-completions request body.

    :param extra: optional mapping merged over the defaults (e.g. ``{"stream": True}``).
    """
    payload = {
        "messages": [{"role": "user", "content": "hi"}],
        "tools": [],
    }
    if extra:
        payload.update(extra)
    return payload


def _mock_completion_client(mock_client):
    """Configure the patched ``AsyncOpenAI`` class to return a canned completion.

    :param mock_client: the mock injected by ``@patch(ASYNC_OPENAI_TARGET)``.
    :returns: the mock client instance, so callers can inspect
        ``chat.completions.create.call_args``.
    """
    mock_response = MagicMock()
    mock_response.model_dump.return_value = {"id": "test", "choices": []}
    mock_instance = MagicMock()
    mock_instance.chat.completions.create = AsyncMock(return_value=mock_response)
    mock_client.return_value = mock_instance
    return mock_instance


class TestVisualizationPluginsApi(IntegrationTestCase):
    """Tests for the visualization plugins API endpoints."""

    @classmethod
    def handle_galaxy_config_kwds(cls, config) -> None:
        # Dummy AI settings; tests that reach the provider patch AsyncOpenAI.
        config["ai_api_key"] = "ai_api_key"
        config["ai_api_base_url"] = "ai_api_base_url"
        config["ai_model"] = "ai_model"
        config["visualization_plugins_directory"] = TEST_VISUALIZATION_PLUGINS_DIR

    def test_index(self):
        """Test that GET /api/plugins returns a list of plugins."""
        response = self._get("plugins")
        self._assert_status_code_is(response, 200)
        plugins = response.json()
        assert isinstance(plugins, list)

    def test_show_returns_all_fields(self):
        """Test that GET /api/plugins/{id} returns all expected fields including params, tags, tests, help, data_sources."""
        response = self._get("plugins/jupyterlite")
        self._assert_status_code_is(response, 200)
        plugin = response.json()

        # Verify required fields
        assert plugin["name"] == "jupyterlite"
        assert plugin["html"] == "JupyterLite Test"
        assert plugin["description"] == "Test fixture for visualization plugin integration tests"
        assert plugin["embeddable"] is False
        assert "entry_point" in plugin
        assert "href" in plugin

        # Verify params are returned correctly
        assert "params" in plugin
        params = plugin["params"]
        assert "dataset_id" in params
        assert params["dataset_id"]["required"] is True
        assert params["dataset_id"]["type"] == "str"

        # Verify data_sources are returned
        assert "data_sources" in plugin
        data_sources = plugin["data_sources"]
        assert len(data_sources) >= 1
        assert data_sources[0]["model_class"] == "HistoryDatasetAssociation"

        # Verify specs are returned
        assert "specs" in plugin
        assert plugin["specs"]["custom_setting"] == "test_value"

        # Verify tags are returned
        assert "tags" in plugin
        tags = plugin["tags"]
        assert "Test" in tags
        assert "Integration" in tags

        # Verify help is returned
        assert "help" in plugin
        assert "test help text" in plugin["help"]

        # Verify tests are returned
        assert "tests" in plugin
        assert len(plugin["tests"]) >= 1

    def _post_payload(self, payload=None, anon=False):
        """POST ``payload`` as JSON to the jupyterlite chat-completions endpoint."""
        return self._post(CHAT_COMPLETIONS_PATH, payload, json=True, anon=anon)

    def _assert_tool_parameters_normalized(self, mock_client, function):
        """Send one tool built from ``function`` and check its parameters are forwarded as an empty object schema."""
        mock_instance = _mock_completion_client(mock_client)
        payload = _create_chat_payload({"tools": [{"type": "function", "function": function}]})
        response = self._post_payload(payload)
        self._assert_status_code_is(response, 200)
        forwarded_tools = mock_instance.chat.completions.create.call_args.kwargs["tools"]
        assert forwarded_tools[0]["function"]["parameters"] == {
            "type": "object",
            "properties": {},
        }

    @patch(ASYNC_OPENAI_TARGET)
    def test_non_streaming_success(self, mock_client):
        """A non-streaming completion is returned as the provider's dumped JSON."""
        _mock_completion_client(mock_client)
        response = self._post_payload(_create_chat_payload(), anon=False)
        self._assert_status_code_is(response, 200)
        assert response.json()["id"] == "test"

    @patch(ASYNC_OPENAI_TARGET)
    def test_streaming_success(self, mock_client):
        """A streamed completion is relayed as ``data:`` events ending with ``data: [DONE]``."""

        async def stream_gen():
            chunk1 = MagicMock()
            chunk1.model_dump.return_value = {"choices": [{"delta": {"content": "hello"}}]}
            chunk2 = MagicMock()
            chunk2.model_dump.return_value = {"choices": [{"delta": {"content": "world"}}]}
            yield chunk1
            yield chunk2

        mock_instance = MagicMock()
        mock_instance.chat.completions.create = AsyncMock(return_value=stream_gen())
        mock_instance.close = AsyncMock()
        mock_client.return_value = mock_instance

        response = self._post_payload(_create_chat_payload({"stream": True}), anon=False)
        self._assert_status_code_is(response, 200)
        body = response.text
        # Two content chunks plus the terminating [DONE] event.
        assert body.count("data:") == 3
        assert "hello" in body
        assert "world" in body
        assert body.rstrip().endswith("data: [DONE]")
        assert mock_instance.chat.completions.create.called
        assert mock_instance.close.called

    def test_tools_exceed_max(self):
        """Requests with too many tools are rejected with an explanatory error."""
        tool = {"type": "function", "function": {"name": "f", "parameters": {}}}
        response = self._post_payload(_create_chat_payload({"tools": [tool] * TOO_MANY_TOOLS}))
        assert "Number of tools exceeded" in response.json()["error"]["message"]

    def test_tool_schema_too_large(self):
        """Requests whose tool parameter schema is oversized are rejected."""
        big_params = {"x": "a" * OVERSIZED_SCHEMA_CHARS}
        payload = _create_chat_payload(
            {"tools": [{"type": "function", "function": {"name": "f", "parameters": big_params}}]}
        )
        response = self._post_payload(payload)
        assert "Tool schema too large" in response.json()["error"]["message"]

    def test_exceed_max_messages(self):
        """Requests with too many messages are rejected."""
        msgs = {"messages": [{"role": "user", "content": "x"}] * TOO_MANY_MESSAGES}
        response = self._post_payload(_create_chat_payload(msgs))
        assert "You have exceeded the number of maximum messages" in response.json()["error"]["message"]

    @patch(ASYNC_OPENAI_TARGET)
    def test_assistant_content_and_tool_calls_preserved(self, mock_client):
        """Assistant messages keep both their text content and tool_calls when forwarded."""
        mock_instance = _mock_completion_client(mock_client)
        payload = {
            "messages": [
                {
                    "role": "assistant",
                    "content": "I will call a tool",
                    "tool_calls": [
                        {
                            "id": "call_1",
                            "type": "function",
                            "function": {
                                "name": "choose_process",
                                "arguments": "{}",
                            },
                        }
                    ],
                }
            ],
            "tools": [],
        }
        response = self._post_payload(payload)
        self._assert_status_code_is(response, 200)

        forwarded_messages = mock_instance.chat.completions.create.call_args.kwargs["messages"]
        assistant_msgs = [m for m in forwarded_messages if m["role"] == "assistant"]
        assert len(assistant_msgs) == 1
        assert assistant_msgs[0]["content"] == "I will call a tool"
        assert "tool_calls" in assistant_msgs[0]
        assert assistant_msgs[0]["tool_calls"][0]["function"]["name"] == "choose_process"
        assert assistant_msgs[0]["tool_calls"][0]["function"]["arguments"] == "{}"

    @patch(ASYNC_OPENAI_TARGET)
    def test_tool_description_preserved(self, mock_client):
        """A tool's description is forwarded unchanged to the provider."""
        mock_instance = _mock_completion_client(mock_client)
        payload = _create_chat_payload(
            {
                "tools": [
                    {
                        "type": "function",
                        "function": {
                            "name": "choose_process",
                            "description": "Select a processing step",
                            "parameters": {"type": "object"},
                        },
                    }
                ]
            }
        )
        response = self._post_payload(payload)
        self._assert_status_code_is(response, 200)
        forwarded_tools = mock_instance.chat.completions.create.call_args.kwargs["tools"]
        assert forwarded_tools[0]["function"]["description"] == "Select a processing step"

    @patch(ASYNC_OPENAI_TARGET)
    def test_tool_with_none_parameters_normalized(self, mock_client):
        """``parameters: None`` is forwarded as an empty object schema."""
        self._assert_tool_parameters_normalized(
            mock_client,
            {
                "name": "choose_process",
                "description": "Select a processing step",
                "parameters": None,
            },
        )

    @patch(ASYNC_OPENAI_TARGET)
    def test_tool_with_missing_parameters_normalized(self, mock_client):
        """A tool without a ``parameters`` key is forwarded with an empty object schema."""
        self._assert_tool_parameters_normalized(
            mock_client,
            {
                "name": "choose_process",
                "description": "Select a processing step",
            },
        )

    @patch(ASYNC_OPENAI_TARGET)
    def test_provider_error_body_forwarded(self, mock_client):
        """Provider APIError status and body are relayed to the caller."""

        class MockOpenAIError(openai.APIError):
            def __init__(self):
                super().__init__(
                    message="original error message",
                    request=cast(Any, object()),
                    body={
                        "message": "original error message",
                        "type": "api_error",
                        "param": None,
                        "code": None,
                    },
                )
                self.status_code = 404

        mock_instance = MagicMock()
        mock_instance.chat.completions.create = AsyncMock(side_effect=MockOpenAIError())
        mock_client.return_value = mock_instance

        response = self._post_payload(_create_chat_payload())
        self._assert_status_code_is(response, 404)
        body = response.json()
        assert body["error"]["message"] == "original error message"
        assert body["error"]["type"] == "api_error"


class TestPluginsInferenceServicesConfig(IntegrationTestCase):
    """Tests for inference_services config resolution in plugins."""

    @classmethod
    def handle_galaxy_config_kwds(cls, config) -> None:
        config["ai_api_key"] = "global_key"
        config["ai_api_base_url"] = "http://global-url"
        config["ai_model"] = "global_model"
        config["visualization_plugins_directory"] = TEST_VISUALIZATION_PLUGINS_DIR
        config["inference_services"] = {
            "default": {
                "model": "default_model",
                "api_key": "default_key",
                "api_base_url": "http://default-url",
            },
            "jupyterlite": {
                "model": "jupyterlite_model",
                "api_key": "jupyterlite_key",
                "api_base_url": "http://jupyterlite-url",
            },
        }

    def _post_payload(self, payload=None, anon=False):
        """POST ``payload`` as JSON to the jupyterlite chat-completions endpoint."""
        return self._post(CHAT_COMPLETIONS_PATH, payload, json=True, anon=anon)

    @patch(ASYNC_OPENAI_TARGET)
    def test_plugin_specific_config_used(self, mock_client):
        """Plugin-specific inference_services config overrides default and global."""
        mock_instance = _mock_completion_client(mock_client)

        response = self._post_payload(payload=_create_chat_payload())
        self._assert_status_code_is(response, 200)

        call_kwargs = mock_instance.chat.completions.create.call_args.kwargs
        assert call_kwargs["model"] == "jupyterlite_model"
        client_kwargs = mock_client.call_args.kwargs
        assert client_kwargs["api_key"] == "jupyterlite_key"
        assert client_kwargs["base_url"] == "http://jupyterlite-url"


class TestPluginsInferenceServicesDefault(IntegrationTestCase):
    """Tests that inference_services.default is used when no plugin-specific config exists."""

    @classmethod
    def handle_galaxy_config_kwds(cls, config) -> None:
        config["ai_api_key"] = "global_key"
        config["ai_api_base_url"] = "http://global-url"
        config["ai_model"] = "global_model"
        config["visualization_plugins_directory"] = TEST_VISUALIZATION_PLUGINS_DIR
        config["inference_services"] = {
            "default": {
                "model": "default_model",
                "api_key": "default_key",
                "api_base_url": "http://default-url",
            },
        }

    def _post_payload(self, payload=None, anon=False):
        """POST ``payload`` as JSON to the jupyterlite chat-completions endpoint."""
        return self._post(CHAT_COMPLETIONS_PATH, payload, json=True, anon=anon)

    @patch(ASYNC_OPENAI_TARGET)
    def test_default_config_fallback(self, mock_client):
        """inference_services.default is used when no plugin-specific entry exists."""
        mock_instance = _mock_completion_client(mock_client)

        response = self._post_payload(payload=_create_chat_payload())
        self._assert_status_code_is(response, 200)

        call_kwargs = mock_instance.chat.completions.create.call_args.kwargs
        assert call_kwargs["model"] == "default_model"
        client_kwargs = mock_client.call_args.kwargs
        assert client_kwargs["api_key"] == "default_key"
        assert client_kwargs["base_url"] == "http://default-url"