diff --git a/Common/Data/SubscriptionDataSource.cs b/Common/Data/SubscriptionDataSource.cs index 12d9a883ffd8..2e7daf580d60 100644 --- a/Common/Data/SubscriptionDataSource.cs +++ b/Common/Data/SubscriptionDataSource.cs @@ -16,6 +16,7 @@ using System; using System.Linq; using System.Collections.Generic; +using Python.Runtime; using static QuantConnect.StringExtensions; namespace QuantConnect.Data @@ -79,7 +80,7 @@ public SubscriptionDataSource(string source, SubscriptionTransportMedium transpo /// The transport medium to be used to retrieve the subscription's data from the source /// The format of the data within the source public SubscriptionDataSource(string source, SubscriptionTransportMedium transportMedium, FileFormat format) - : this(source, transportMedium, format, null) + : this(source, transportMedium, format, (IEnumerable>)null) { } @@ -99,6 +100,19 @@ public SubscriptionDataSource(string source, SubscriptionTransportMedium transpo Headers = headers?.ToList() ?? _empty; } + /// + /// Initializes a new instance of the class with + /// including the specified header values as a Python dictionary. + /// + /// The subscription's data source location + /// The transport medium to be used to retrieve the subscription's data from the source + /// The format of the data within the source + /// The Python dictionary containing the headers to be used for this source + public SubscriptionDataSource(string source, SubscriptionTransportMedium transportMedium, FileFormat format, PyObject headers) + : this(source, transportMedium, format, headers == null ? null : headers.ConvertToDictionary()) + { + } + /// /// Indicates whether the current object is equal to another object of the same type. /// diff --git a/Tests/Common/Data/SubscriptionDataSourceTests.cs b/Tests/Common/Data/SubscriptionDataSourceTests.cs index 6dbb2f6890c3..5ed9deda7c1a 100644 --- a/Tests/Common/Data/SubscriptionDataSourceTests.cs +++ b/Tests/Common/Data/SubscriptionDataSourceTests.cs @@ -14,7 +14,10 @@ */ using NUnit.Framework; +using Python.Runtime; using QuantConnect.Data; +using System; +using System.Collections.Generic; namespace QuantConnect.Tests.Common.Data { @@ -47,5 +50,44 @@ public void ComparesNotEqualWithDifferentTransportMedium() Assert.IsTrue(one != two); Assert.IsTrue(!one.Equals(two)); } + + [Test] + public void SupportsPythonDictionaryHeaders() + { + using (Py.GIL()) + { + using var headers = new PyDict(); + headers.SetItem("Authorization".ToPython(), "Basic test-token".ToPython()); + headers.SetItem("X-Api-Key".ToPython(), "abc123".ToPython()); + + var dataSource = new SubscriptionDataSource("https://example.com", SubscriptionTransportMedium.RemoteFile, FileFormat.Csv, headers); + CollectionAssert.AreEquivalent(new[] + { + new KeyValuePair("Authorization", "Basic test-token"), + new KeyValuePair("X-Api-Key", "abc123") + }, dataSource.Headers); + } + } + + [Test] + public void SupportsNullPythonDictionaryHeaders() + { + var dataSource = new SubscriptionDataSource("https://example.com", SubscriptionTransportMedium.RemoteFile, FileFormat.Csv, (PyObject)null); + Assert.IsEmpty(dataSource.Headers); + } + + [Test] + public void ThrowsForInvalidPythonHeadersType() + { + using (Py.GIL()) + { + using var invalidHeaders = "invalid-headers".ToPython(); + + var exception = Assert.Throws(() => + new SubscriptionDataSource("https://example.com", SubscriptionTransportMedium.RemoteFile, FileFormat.Csv, invalidHeaders)); + + StringAssert.Contains("ConvertToDictionary cannot be used", exception.Message); + } + } } }