Browse Source

fix: executors type

pull/14/head
QuentinN42 1 year ago
parent
commit
89dd528ac4
Signed by: number42 GPG Key ID: 2CD7D563712B3A50
  1. 4
      .vscode/settings.json
  2. 26
      auto_trading/interfaces.py
  3. 1
      requirements.txt
  4. 32
      tests/test_interfaces.py

4
.vscode/settings.json

@ -6,5 +6,7 @@
"testOnSave.testCommand": "./scripts/test.sh",
"python.testing.pytestArgs": [
"tests"
]
],
"python.formatting.provider": "black",
"editor.formatOnSave": true,
}

26
auto_trading/interfaces.py

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod, abstractproperty
from dataclasses import dataclass, field
from datetime import timedelta, datetime
from pandas import DataFrame
from typing import Dict, List, Optional, Type, Callable
from typing import Dict, List, Optional, Any, Callable
from .errors import OrderException, UnknowOrder, PTFException
@ -26,15 +26,15 @@ class CandlesProperties:
class DataBroker(ABC):
""" Somethink that give you data. """
def __init__(self):
""" Init the class. """
self.logger = logging.getLogger(self.__class__.__name__)
@abstractproperty
def properties(self) -> CandlesProperties:
""" Return the properties of the candles for this broker. """
@abstractproperty
def current_change(self) -> DataFrame:
""" Return the current change for each money. """
@ -57,11 +57,11 @@ class DataBroker(ABC):
class Indicator(ABC):
""" Somethink that give you an insight of the market. """
def __init__(self):
""" Init the class. """
self.logger = logging.getLogger(self.__class__.__name__)
@abstractmethod
def __call__(self, data: DataFrame) -> DataFrame:
"""Return a dataframe of valuation of each stock from the input data.
@ -99,7 +99,8 @@ class Strategy(ABC):
return self.execute(data, indicators_results)
@abstractmethod
def execute(self, data: DataFrame, indicators_results: DataFrame) -> List[Order]:
def execute(self, data: DataFrame,
indicators_results: DataFrame) -> List[Order]:
"""Execute the strategy with the indicators insights.
Args:
@ -116,10 +117,10 @@ class Strategy(ABC):
class PTF(ABC):
""" Somethink that buy or sell stocks."""
executors: Dict[object, Callable[["PTF", Order], None]] = {}
executors: Dict[Any, Callable[[Any, Any], None]] = {}
history: List[Order]
def __init__(self, skip_errors: bool = True, save_errors: bool = True):
""" Init the class.
@ -131,7 +132,7 @@ class PTF(ABC):
self.history = []
self.skip_errors = skip_errors
self.save_errors = save_errors
@abstractproperty
def balance(self) -> float:
""" Return the current total balance. """
@ -149,7 +150,8 @@ class PTF(ABC):
order.successfull = True
except OrderException as e:
if not self.skip_errors:
raise PTFException(f"Got and order exception : {e.message}") from e
raise PTFException(
f"Got and order exception : {e.message}") from e
self.logger.warn("Got an order exception : %s", e.message)
order.successfull = False
if self.save_errors or order.successfull:

1
requirements.txt

@ -5,3 +5,4 @@ tqdm
plotly
pytest
pytest-cov
black

32
tests/test_interfaces.py

@ -5,7 +5,6 @@ from auto_trading.errors import UnknowOrder
from auto_trading.interfaces import PTF
from auto_trading.orders import Short, Long
long_order = Long(datetime.now(), "BTC", 1, 1)
short_order = Short(datetime.now(), "GOLD", 1, 1)
@ -15,9 +14,9 @@ class _TestPTFLong(PTF):
def balance(self):
return 0
def execute_long(self, order: Long):
def execute_long(self, order: Long) -> None:
assert order == long_order
executors = {Long: execute_long}
@ -26,12 +25,12 @@ class _TestPTFLongShort(PTF):
def balance(self):
return 0
def execute_long(self, order: Long):
def execute_long(self, order: Long) -> None:
assert order == long_order
def execute_short(self, order: Long):
def execute_short(self, order: Long) -> None:
assert order == short_order
executors = {Long: execute_long, Short: execute_short}
@ -40,9 +39,9 @@ class _TestPTFShort(PTF):
def balance(self):
return 0
def execute_short(self, order: Long):
def execute_short(self, order: Long) -> None:
assert order == short_order
executors = {Short: execute_short}
@ -52,12 +51,15 @@ class _TestPTFNone(PTF):
return 0
@pytest.mark.parametrize("test_class, must_match", [
(_TestPTFLong, [long_order]),
(_TestPTFShort, [short_order]),
(_TestPTFLongShort, [long_order, short_order]),
(_TestPTFNone, [])
])
@pytest.mark.parametrize(
"test_class, must_match",
[
(_TestPTFLong, [long_order]),
(_TestPTFShort, [short_order]),
(_TestPTFLongShort, [long_order, short_order]),
(_TestPTFNone, []),
],
)
def test_ptf_execution(test_class, must_match):
inst = test_class()
for o in [short_order, long_order]:

Loading…
Cancel
Save